# -*- coding: utf-8 -*-
"""
Created on Mon Jul 20 22:50:08 2020

@author: 81283
"""
from policy.value import q_dict
from policy.value import q_value_init      # q值表格初始化函数
from policy.q_and_sarsa import q_learning, sarsa_learning  # 进行一步q_learning， sarsa的函数
from policy.q_and_sarsa import do_final_reward  # 最终确定性策略下的奖励
from environment.env import x_scope, y_scope, actions
import numpy as np


# Fixed mapping from action name to its index along the last axis of the
# q array.  Using an explicit dict (instead of an if/elif chain) makes an
# unknown action fail loudly with KeyError rather than silently reusing a
# stale index variable.
_ACTION_INDEX = {"up": 0, "down": 1, "left": 2, "right": 3}


def q2matrix():
    """Convert the tabular ``q_dict`` into a dense numpy array.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(x_scope, y_scope, len(actions))`` where entry
        ``[i, j, a]`` is the learned action value q((i, j), action_a).
        State/action pairs absent from ``q_dict`` stay 0.0.
    """
    q_array = np.zeros((x_scope, y_scope, len(actions)))
    for ((state_i, state_j), action), q in q_dict.items():
        q_array[state_i, state_j, _ACTION_INDEX[action]] = q
    return q_array


def q2v():
    """Convert action values q to state values v under a greedy policy.

    Builds the dense q array (via :func:`q2matrix`, removing the previous
    copy-pasted duplicate of its body) and takes the max over the action
    axis, i.e. v(s) = max_a q(s, a).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(x_scope, y_scope)`` with the greedy state values.
    """
    return q2matrix().max(axis=2)
    

def _run_experiments(learn_step, algo_name, trials, episodes, no_cliff):
    """Run ``trials`` independent experiments of ``episodes`` episodes each.

    This factors out the loop that was previously duplicated verbatim for
    Q-learning and SARSA.

    Parameters
    ----------
    learn_step : callable
        One learning step: ``learn_step(state, no_cliff) ->
        (next_state, reward, end_flag)`` (i.e. ``q_learning`` or
        ``sarsa_learning``).
    algo_name : str
        Name printed in the per-trial progress message.
    trials : int
        Number of independent experiments (the q table is re-initialized
        before each one).
    episodes : int
        Training episodes per experiment.
    no_cliff : bool
        Environment flag: whether the cliff is disabled.

    Returns
    -------
    tuple
        ``(reward_trials, q_history, final_rewards, final_visits,
        no_final_rewards, no_final_visits)`` —
        per-trial lists of per-episode reward sums; per-trial lists of
        q-array snapshots (one per episode); deterministic-policy returns
        and visit histories for trials whose greedy policy reached the
        goal; and returns / visit histories for trials whose greedy policy
        did not (visit histories for failed trials are not recorded,
        matching the original behaviour).
    """
    reward_trials = []      # one list of per-episode reward sums per trial
    q_history = []          # one list of q-array snapshots per trial
    final_rewards, final_visits = [], []        # greedy policy reached goal
    no_final_rewards, no_final_visits = [], []  # greedy policy failed

    for trial in range(trials):
        reward_episodes = []
        q_per_episode = []
        q_value_init()  # reset the q table before every experiment

        for _episode in range(episodes):
            episode_reward = 0  # undiscounted sum of step rewards
            state = (0, 0)
            while True:
                state, reward, end_flag = learn_step(state, no_cliff)
                episode_reward += reward
                if end_flag:
                    reward_episodes.append(episode_reward)
                    break
            # Snapshot the full action-value table after each episode so
            # the learning dynamics can be plotted later.
            q_per_episode.append(q2matrix())

        reward_trials.append(reward_episodes)
        q_history.append(q_per_episode)

        # Evaluate the deterministic (greedy) policy learned in this trial.
        reward_end, solved, visit_hist = do_final_reward(no_cliff)
        if solved:
            final_rewards.append(reward_end)
            final_visits.append(visit_hist)
        else:
            no_final_rewards.append(reward_end)

        # Same output bytes as the original per-algorithm print statements.
        print("%s部分 第%d次实验 over" % (algo_name, trial))

    return (reward_trials, q_history, final_rewards, final_visits,
            no_final_rewards, no_final_visits)


def _q_value_dynamics(q_history):
    """Per-episode L1 change of the q table, averaged over trials.

    For each trial, computes sum(|q_t+1 - q_t|) over all state/action
    entries between consecutive episodes, then averages across trials.
    """
    deltas = []
    for per_trial in q_history:
        prev, curr = np.array(per_trial[:-1]), np.array(per_trial[1:])
        deltas.append(np.sum(np.abs(curr - prev), axis=(1, 2, 3)))
    return np.array(deltas).mean(axis=0)


if __name__ == "__main__":
    no_cliff = False   # whether the cliff is disabled in the environment
    trials = 10000     # number of independent experiments
    episodes = 1000    # episodes per experiment

    # Q-learning experiments.
    (reward_trials_q, q_q, final_reward_q, final_visit_q,
     no_final_reward_q, no_final_visit_q) = _run_experiments(
        q_learning, "q_learning", trials, episodes, no_cliff)

    # SARSA experiments.
    (reward_trials_sarsa, q_sarsa, final_reward_sarsa, final_visit_sarsa,
     no_final_reward_sarsa, no_final_visit_sarsa) = _run_experiments(
        sarsa_learning, "sarsa_learning", trials, episodes, no_cliff)

    # ---- plot 1: average per-episode reward over all trials ----
    import matplotlib.pyplot as plt

    points_q = np.array(reward_trials_q).mean(axis=0)
    points_sarsa = np.array(reward_trials_sarsa).mean(axis=0)

    plt.plot(points_sarsa, 'b', label="sarsa-learning")
    plt.plot(points_q, 'r', label="q-learning")
    plt.legend(loc=0)
    plt.xlabel('episode')
    plt.ylabel('reward sum per episode')
    plt.xlim(0, episodes)
    plt.ylim(-100, 0)
    plt.title("SARSA & Q-LEARNING")
    plt.show()

    # ---- deterministic-policy return statistics ----
    # NOTE(review): np.mean/var/median of an empty list produce nan with a
    # RuntimeWarning when every trial (or none) solved the task — original
    # behaviour kept.
    print("q学习后可以找到终点的确定性策略的回报：")
    print("均值 %f , 方差 %f , 中位数 %f" % (np.mean(final_reward_q),
                    np.var(final_reward_q), np.median(final_reward_q)))
    print("q学习后不能找到终点的确定性策略的回报：")
    print("均值 %f , 方差 %f , 中位数 %f" % (np.mean(no_final_reward_q),
                    np.var(no_final_reward_q), np.median(no_final_reward_q)))

    print("sarsa学习后可以找到终点的确定性策略的回报：")
    print("均值 %f , 方差 %f , 中位数 %f" % (np.mean(final_reward_sarsa),
                    np.var(final_reward_sarsa), np.median(final_reward_sarsa)))
    print("sarsa学习后不能找到终点的确定性策略的回报：")
    print("均值 %f , 方差 %f , 中位数 %f" % (np.mean(no_final_reward_sarsa),
                    np.var(no_final_reward_sarsa), np.median(no_final_reward_sarsa)))

    # ---- plot 2: q-value change per episode (learning dynamics) ----
    q_point = _q_value_dynamics(q_q)
    sarsa_point = _q_value_dynamics(q_sarsa)

    # BUG FIX: the original swapped these labels (q_point was plotted as
    # "sarsa-learning" and sarsa_point as "q-learning"); colours now match
    # the first plot (sarsa blue, q-learning red).
    plt.plot(sarsa_point, 'b', label="sarsa-learning")
    plt.plot(q_point, 'r', label="q-learning")
    plt.legend(loc=0)
    plt.xlabel('episodes')
    plt.ylabel('q_value_sum_dynamic')
    plt.xlim(0, episodes)
    plt.title("SARSA & Q-LEARNING")
    plt.show()